
Conversation

@NickLucche
Collaborator

@NickLucche NickLucche commented Mar 26, 2025

It appears the recompilation was due to a misconfiguration when pre-compiling for num_reqs.
Added tests to CI to confirm this upstream.

@NickLucche NickLucche marked this pull request as ready for review March 26, 2025 14:44
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@robertgshaw2-redhat
Collaborator

nice job nicolo!

Comment on lines 91 to 93
Member

MIN_NUM_SEQS is 8; should this be rounded up to the nearest multiple of MIN_NUM_SEQS?

Collaborator Author

Not necessarily; see capture_model: max_num_reqs still gets compiled anyway, due to how padding with an upper limit is implemented.

Collaborator Author

BTW, I'm OK with forcing everything to be nicely divisible by MIN_NUM_SEQS; I just remember that max_num_seqs used to be padded and then it was changed for a reason.
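
For reference, here is a minimal sketch of the rounding being discussed, assuming MIN_NUM_SEQS = 8 as stated above; the helper name is illustrative and does not correspond to an actual vLLM function.

# Illustrative sketch only: round a request count up to the nearest
# multiple of MIN_NUM_SEQS (8, per the comment above).
MIN_NUM_SEQS = 8

def round_up_to_multiple(num_reqs: int, multiple: int = MIN_NUM_SEQS) -> int:
    return ((num_reqs + multiple - 1) // multiple) * multiple

assert round_up_to_multiple(1) == 8
assert round_up_to_multiple(9) == 16
assert round_up_to_multiple(16) == 16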

@yarongmu-google
Contributor

It seems tests/tpu/test_compilation.py failed with this error:

Processed prompts: 100%|████████████████████████████████████████████████████| 1/1 [00:45<00:00, 46.00s/it, est. speed input: 109.60 toks/s, output: 0.11 toks/s]

  | WARNING 03-26 16:58:08 [parallel_state.py:1093] torch._C._host_emptyCache() only available in Pytorch >=2.5
  | FAILEDWARNING 03-26 16:58:08 [parallel_state.py:1093] torch._C._host_emptyCache() only available in Pytorch >=2.5

Source: https://buildkite.com/vllm/fastcheck/builds/18374#0195d341-8f23-4eb8-bd34-998e6055b2ff

@bvrockwell
Contributor


@NickLucche would we be able to set --enforce-eager=False on all tests except test_compilation.py, please, to capture whether recompilation is fully resolved?

@robertgshaw2-redhat
Collaborator

It’s false by default.

@bvrockwell
Contributor

It’s false by default.

https://github.com/NickLucche/vllm/blob/tpu-fix-recompilation/tests/v1/tpu/test_basic.py#L34

Am I reading this wrong? Where should I be checking this otherwise?

Collaborator

@NickLucche thank you for fixing this! Quick question: does this change actually fix the recompilation issue?

Collaborator

I have the same confusion. IMO, the code fixes the recompilation issue when max_num_reqs is not a power of 2, but in our tests it's already a power of 2.
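
To make the power-of-2 point concrete, here is a rough, purely illustrative sketch of that kind of bucketing; the helper is hypothetical and is not vLLM's actual padding code.

# Hypothetical illustration: a max_num_reqs that is already a power of 2
# (e.g. 16) maps to itself, while a non-power-of-2 value (e.g. 20) pads
# up to 32, which is the case the fix targets.
def next_power_of_2(n: int) -> int:
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

assert next_power_of_2(16) == 16  # already a power of 2, as in the tests
assert next_power_of_2(20) == 32  # non-power-of-2 value gets padded up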

Contributor

Does it make sense to also set enforce_eager=False in test_basic.py?

Collaborator

I think so. We should use the default value of enforce_eager (which is False) in most cases.

Contributor

Indeed, it is currently set to True in test_basic.py.
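
For reference, a minimal sketch of the kind of test setup being discussed, constructing vllm.LLM directly with enforce_eager left at its default of False; the model name, prompt, and surrounding test code are illustrative assumptions, not taken from the PR.

# Illustrative only: a tiny smoke test that leaves enforce_eager=False
# (the default) so the compiled path is exercised and recompilation
# issues can surface.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # example model, not from the PR
    enforce_eager=False,                 # default value, kept explicit here
    max_num_seqs=16,
)
outputs = llm.generate(["Hello, TPU!"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)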

@mergify mergify bot added the tpu Related to Google TPUs label Mar 27, 2025
@mergify

mergify bot commented Mar 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 27, 2025
Collaborator

@yaochengji yaochengji left a comment


@NickLucche Thanks for fixing this. Left a few comments.

Collaborator

Why don't we use torch.compile for this method?

It's usually easier to know the boundary of the TPU computation and avoid recompilation if we wrap it inside torch.compile.

Collaborator Author

Once main is stable we'll turn that on; I'd like to do that in a separate PR. Last time around, compilation got slower, so I just want to be cautious.

Collaborator

I did an experiment based on your branch, cleaning the XLA cache each time before execution:
without torch.compile: sampler pre-compilation time is 35.79 [secs]
with torch.compile: sampler pre-compilation time is 35.51 [secs]

The compilation time difference is negligible. Meanwhile, torch.compile can speed up execution, because the guard check of torch.compile is usually faster than torch/xla's graph tracing.
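
For context, a minimal sketch of what wrapping a sampler-like module in torch.compile could look like; the Sampler module and the choice of the openxla backend are assumptions for illustration, not vLLM's actual code.

# Minimal sketch, not vLLM's implementation: wrapping a TPU-side module in
# torch.compile makes the compiled region's boundary explicit, and repeated
# lazy-tensor tracing is replaced by torch.compile's guard checks.
import torch

class Sampler(torch.nn.Module):
    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        # Greedy sampling as a stand-in for the real sampling logic.
        return torch.argmax(logits, dim=-1)

sampler = Sampler()
# On TPU via torch_xla this would typically use backend="openxla".
compiled_sampler = torch.compile(sampler, backend="openxla", dynamic=False)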

Collaborator Author

Happy to add it back then! I'd just like to merge this PR first, if you don't mind; it still fixes the pre-compilation when max_num_reqs is not a power of 2.

Collaborator

Could you add it here in this PR? It's just a one-line code change.


Collaborator

We'd better add xm.mark_step() before this line, unless we use torch.compile.

Collaborator Author

Isn't it redundant? A sync to CPU will still cause the graph to be flushed and executed.

Collaborator

TL;DR: xm.mark_step computes all the pending tensors, while B.cpu() only computes the exact tensor B.

E.g., say we have the code:

A = ...
B = op(A)

The graph outputs generated by xm.mark_step() and B.cpu() are different:

For xm.mark_step(), we get both A and B as outputs.
For B.cpu(), only B is the output.

Then if we have another xm.mark_step() later, since A's result was not returned by the previous computation, we have to compute A again.
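
A small sketch of the behaviour described above, assuming torch_xla is installed and an XLA (e.g. TPU) device is available; the tensors and ops are arbitrary examples.

# Illustration of the discussion above, not code from this PR.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

A = torch.randn(8, 8, device=device)  # lazy tensor, not yet computed
B = A * 2                             # op(A), also pending

# B.cpu() compiles and runs a graph whose only output is B; A's result is
# not among that graph's outputs.
b_host = B.cpu()

# A later mark_step() materializes everything still pending, so A ends up
# being computed here rather than coming out of the first graph.
xm.mark_step()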

Collaborator Author

Thanks for explaining, this is really helpful.
I don't think this is the case here, as we only need the sampled output tokens from the sampler step, but I could also add it for completeness.

Collaborator

Although we don't intend to use the other tensors, a later xm.mark_step will still try to get their results, so we end up with redundant computation.

BTW, we have an implicit xm.mark_step when using torch.compile. That's one reason I recommend using torch.compile when possible.

Collaborator

And when we use torch.compile, we don't need so many xm.mark_step calls.

torch.compile is also much easier for PyTorch developers to understand.

@NickLucche
Collaborator Author

@NickLucche thank you for fixing this! Quick question: does this change actually fix the recompilation issue?

So I've added the test you mentioned in the previous PR, test_sampler.py, with enforce_eager=False, and you can run it with VLLM_XLA_CHECK_RECOMPILATION=1.
If you did the same thing prior to this PR, it would show that a graph compilation was detected at runtime.

Mentioned it in this thread too #15309.
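
For illustration, roughly how a check along these lines can be expressed with torch_xla's debug metrics; this sketch is not vLLM's actual VLLM_XLA_CHECK_RECOMPILATION implementation, and the metric handling is an assumption.

# Rough sketch only; NOT vLLM's implementation of VLLM_XLA_CHECK_RECOMPILATION.
# It just illustrates asserting that no new XLA graph compilations happen
# after warm-up.
import torch_xla.debug.metrics as met

def num_compilations() -> int:
    data = met.metric_data("CompileTime")
    # metric_data returns None if nothing was recorded yet; otherwise its
    # first element is the number of recorded samples.
    return 0 if data is None else data[0]

baseline = num_compilations()
# ... run the already-warmed-up model / sampler here ...
assert num_compilations() == baseline, "unexpected XLA recompilation at runtime"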

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
@NickLucche NickLucche force-pushed the tpu-fix-recompilation branch from 957f043 to 2e61e31 Compare March 27, 2025 13:03
@mergify mergify bot removed the needs-rebase label Mar 27, 2025
@NickLucche
Collaborator Author

NickLucche commented Mar 27, 2025

Another test worth running, which I had mentioned on Slack but not here, is a benchmark with the recompilation check on.

VLLM_XLA_CHECK_RECOMPILATION=1 VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct \
 --disable-log-requests \
 --port 8004 \
 --gpu-memory-utilization 0.95 \
 --max-num-seqs 512 \
 --max-num-batched-tokens 512 \
 --tensor-parallel-size 1 \
 --max-model-len 2048 > "$VLLM_LOG" 2>&1 &

Feel free to ping me if you spot any case where the check fails though!

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 27, 2025
@mgoin mgoin enabled auto-merge (squash) March 27, 2025 14:11
@DarkLight1337
Member

Can you push a dummy commit to retry the doc build?

Signed-off-by: NickLucche <nlucches@redhat.com>
@NickLucche
Collaborator Author

thanks for the ping @DarkLight1337

Contributor

@bvrockwell bvrockwell left a comment


Thanks for looking into this @NickLucche! @yaochengji recommended this yesterday and it's a good idea: could we please disable enforce_eager for all tests (like test_basic.py, where it is currently still set to True), except for test_compilation.py? This ensures all tests check for recompilation, so nothing slips in inadvertently in the future.

@NickLucche
Collaborator Author

Sure, but it's a bit out of scope for this PR; I'd suggest we open another one.
Here I only meant to fix the issues related to max_num_reqs when compiling on main, plus add a test (test 7) for it.

Collaborator

@yaochengji yaochengji left a comment


Thanks for your fix again. Could you address my comments in this PR?


@mgoin mgoin merged commit 4098b72 into vllm-project:main Mar 27, 2025
32 checks passed
Alex4210987 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Apr 5, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: xinyuxiao <xinyuxiao2024@gmail.com>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>